/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h"
#include "tensorflow/contrib/lite/toco/model.h"
#include "tensorflow/contrib/lite/toco/tooling_util.h"
#include "tensorflow/core/platform/logging.h"

namespace toco {

namespace {

void ComputeConvSizes(const Shape& input_shape, int output_depth, int kwidth,
                      int kheight, int stride_width, int stride_height,
                      PaddingType padding_type, Shape* output_shape,
                      FixedPadding* fixed_padding) {
  const int input_width = input_shape.dims(2);
  const int input_height = input_shape.dims(1);
  const int batch = input_shape.dims(0);

  int output_height = 0;
  int output_width = 0;
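  // Standard convolution output-size arithmetic:
  //   VALID: ceil((input - kernel + 1) / stride) == (input + stride - kernel) / stride
  //   SAME:  ceil(input / stride)                == (input + stride - 1) / stride
  // For example, input_height = 7, kheight = 3, stride_height = 2 gives
  // output_height = 3 under VALID and 4 under SAME.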
  if (padding_type == PaddingType::kValid) {
    output_height = (input_height + stride_height - kheight) / stride_height;
    output_width = (input_width + stride_width - kwidth) / stride_width;
  } else if (padding_type == PaddingType::kSame) {
    output_height = (input_height + stride_height - 1) / stride_height;
    output_width = (input_width + stride_width - 1) / stride_width;
  } else {
    LOG(FATAL) << "Only supporting SAME or VALID padding";
  }

  fixed_padding->height = std::max(
      0, ((output_height - 1) * stride_height + kheight - input_height) / 2);
  fixed_padding->width = std::max(
      0, ((output_width - 1) * stride_width + kwidth - input_width) / 2);
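  // The expressions above record the leading (fixed) padding as the floor of
  // half the total padding (output - 1) * stride + kernel - input, matching
  // TensorFlow's SAME convention of placing any odd leftover unit at the end.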

  // These checks guard against a bug we once had to debug, where the output
  // sizes went negative due to bad propagation of placeholder -1 sizes in
  // TensorFlowReshape.
  CHECK_GT(output_width, 0);
  CHECK_GT(output_height, 0);
  output_shape->ReplaceDims({batch, output_height, output_width, output_depth});
}

void ComputeBinaryOperatorOutputSize(const Shape& input_shape1,
                                     const Shape& input_shape2,
                                     Array* output_array) {
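  // Pick the broadcast output shape: the input with the larger flat buffer
  // size wins, and ties are broken in favor of the higher-rank shape. For
  // example, {8, 1, 6} (48 elements) wins over {6} (6 elements).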
  const int size1 = RequiredBufferSizeForShape(input_shape1);
  const int size2 = RequiredBufferSizeForShape(input_shape2);
  if (size1 > size2) {
    output_array->copy_shape(input_shape1);
  } else if (size2 > size1) {
    output_array->copy_shape(input_shape2);
  } else {
    CHECK_EQ(size1, size2);
    const int dims1 = input_shape1.dimensions_count();
    const int dims2 = input_shape2.dimensions_count();
    if (dims1 >= dims2) {
      output_array->copy_shape(input_shape1);
    } else {
      output_array->copy_shape(input_shape2);
    }
  }
  CHECK(output_array->has_shape());
}

int GetOutputDepthFromWeights(const Model& model, const Operator& op) {
  const string& weights_name = op.inputs[1];
  const auto& weights_shape = model.arrays.at(weights_name)->shape();
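  // In toco's internal weights layout, Conv and FullyConnected store the
  // output depth in the first dimension, while DepthwiseConv stores it in the
  // last dimension.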
  if (op.type == OperatorType::kConv ||
      op.type == OperatorType::kFullyConnected) {
    return weights_shape.dims(0);
  } else if (op.type == OperatorType::kDepthwiseConv) {
    return weights_shape.dims(3);
  } else {
    LOG(FATAL) << "Unhandled operator type";
  }
}

bool EnsureBiasVectorShape(Model* model, Operator* op) {
  const string& weights_name = op->inputs[1];
  const auto& weights_array = *model->arrays[weights_name];
  // Yield until weights shape has been resolved.
  if (!weights_array.has_shape()) {
    return false;
  }

  if (op->inputs.size() < 3) {
    return false;
  }
  auto& bias_array = *model->arrays[op->inputs[2]];
  if (bias_array.has_shape()) {
    return true;
  }

  const int output_depth = GetOutputDepthFromWeights(*model, *op);
  bias_array.copy_shape(Shape({output_depth}));

  auto& float_buffer = bias_array.GetMutableBuffer<ArrayDataType::kFloat>();
  float_buffer.data.resize(output_depth, 0);

  return true;
}

void ProcessConvOperator(Model* model, ConvOperator* op) {
  if (!EnsureBiasVectorShape(model, op)) {
    return;
  }

  const auto& input_array = *model->arrays[op->inputs[0]];
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);

  const auto& weights_array = *model->arrays[op->inputs[1]];
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 4);

  auto& output_array = model->GetArray(op->outputs[0]);
  const int output_depth = weights_shape.dims(0);
  const int kheight = weights_shape.dims(1);
  const int kwidth = weights_shape.dims(2);
  ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
                   op->stride_height, op->padding.type,
                   output_array.mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
  CHECK_EQ(output_array.shape().dimensions_count(), 4);

  // Set im2col array dimensions if there is one.
  if (op->outputs.size() == 2) {
    const auto& output_shape = output_array.shape();
    const int input_depth = weights_shape.dims(3);
    auto& im2col_array = *model->arrays[op->outputs[1]];
    im2col_array.copy_shape(Shape{output_shape.dims(0), output_shape.dims(1),
                                  output_shape.dims(2),
                                  input_depth * kheight * kwidth});
  }
}

void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
  if (!EnsureBiasVectorShape(model, op)) {
    return;
  }

  const auto& input_array = *model->arrays[op->inputs[0]];
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);

  const auto& weights_array = *model->arrays[op->inputs[1]];
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 4);

  const string& output_name = op->outputs[0];
  const int input_depth = input_shape.dims(3);
  const int output_depth = weights_shape.dims(3);
  // TensorFlow doesn't define the depth_multiplier value on DepthwiseConv
  // ops; instead it has to be inferred from the weights dims. However, once
  // we are here, the weights dims have already been converted to our own
  // internal format, where the multiplier is no longer readily apparent. So
  // instead we recover it as the quotient of output and input depths. We only
  // do that when depth_multiplier still has its placeholder value of zero;
  // any other value is validated by the QCHECK below.
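  // For example, input_depth 8 and output_depth 24 imply depth_multiplier 3.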
  if (!op->depth_multiplier) {
    op->depth_multiplier = output_depth / input_depth;
  }
  QCHECK_EQ(output_depth, input_depth * op->depth_multiplier)
      << "input/output depths and depth_multiplier don't match";

  const int kheight = weights_shape.dims(1);
  const int kwidth = weights_shape.dims(2);
  ComputeConvSizes(input_shape, output_depth, kwidth, kheight, op->stride_width,
                   op->stride_height, op->padding.type,
                   model->GetArray(output_name).mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
}

void ProcessDepthToSpaceOperator(Model* model, DepthToSpaceOperator* op) {
  const auto& input_array = *model->arrays[op->inputs[0]];
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);

  const string& output_name = op->outputs[0];
  const int block_size = op->block_size;
  CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
  const int batch = input_shape.dims(0);
  const int height = input_shape.dims(1);
  const int width = input_shape.dims(2);
  const int depth = input_shape.dims(3);
  QCHECK_EQ(depth % (block_size * block_size), 0);
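  // For example, block_size 2 maps a [1, 4, 6, 8] input to [1, 8, 12, 2].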

  model->GetArray(output_name)
      .copy_shape(Shape({batch, height * block_size, width * block_size,
                         depth / block_size / block_size}));
}

void ProcessSpaceToDepthOperator(Model* model, SpaceToDepthOperator* op) {
  const auto& input_array = *model->arrays[op->inputs[0]];
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);

  const string& output_name = op->outputs[0];
  const int block_size = op->block_size;
  CHECK_NE(block_size, 0) << "Invalid block_size in " << output_name;
  const int batch = input_shape.dims(0);
  const int height = input_shape.dims(1);
  const int width = input_shape.dims(2);
  const int depth = input_shape.dims(3);
  QCHECK_EQ(width % block_size, 0);
  QCHECK_EQ(height % block_size, 0);
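  // For example, block_size 2 maps a [1, 8, 12, 2] input to [1, 4, 6, 8].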

  model->GetArray(output_name)
      .copy_shape(Shape({batch, height / block_size, width / block_size,
                         depth * block_size * block_size}));
}

void ProcessFullyConnectedOperator(Model* model, FullyConnectedOperator* op) {
  if (!EnsureBiasVectorShape(model, op)) {
    return;
  }

  const auto& input_array = *model->arrays[op->inputs[0]];
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_GE(input_shape.dimensions_count(), 1);

  const auto& weights_array = *model->arrays[op->inputs[1]];
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();

  const int weights_output_depth = weights_shape.dims(0);
  CHECK_EQ(weights_shape.dimensions_count(), 2);

  const int input_overall_size = RequiredBufferSizeForShape(input_shape);
  const int matmul_repeats = input_overall_size / weights_shape.dims(1);
  CHECK_EQ(matmul_repeats * weights_shape.dims(1), input_overall_size);
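  // The input is treated as a [matmul_repeats, weights_shape.dims(1)] matrix;
  // e.g. a [2, 4, 10] input against 10-deep weights gives matmul_repeats = 8.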

  auto& output_array = model->GetArray(op->outputs[0]);
  output_array.copy_shape(Shape({matmul_repeats, weights_output_depth}));
}

void ProcessTensorFlowReshapeOperator(Model* model,
                                      TensorFlowReshapeOperator* op) {
  auto& output_array = *model->arrays[op->outputs[0]];
  // Bail if we already have output dims.
  if (output_array.has_shape()) {
    return;
  }

  const auto& input_array = *model->arrays[op->inputs[0]];
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();

  const string& shape_name = op->inputs[1];
  auto& shape_array = model->GetArray(shape_name);
  // Yield until the shape is resolved as a constant array.
  if (!shape_array.buffer) {
    return;
  }
  CHECK(shape_array.data_type == ArrayDataType::kInt32);
  // shape_data is the raw array of ints describing the shape
  // in the TensorFlow node. We intentionally make a copy here, rather than
  // modify wildcards in-place below, because in some graphs, the same shape
  // array with a wildcard may be referenced from multiple Reshape nodes, where
  // the wildcard needs to be resolved to distinct values.
  std::vector<int32> shape_data =
      shape_array.GetBuffer<ArrayDataType::kInt32>().data;
  // The Reshape shape may have a wildcard dim, encoded as -1.
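  // For example, reshaping a [2, 3, 4] input (24 elements) with shape data
  // {6, -1} resolves the wildcard to 24 / 6 = 4, i.e. an output of [6, 4].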
  bool has_wildcard = false;
  int wildcard_index = 0;
  int product_non_wildcard_dims = 1;
  for (int i = 0; i < shape_data.size(); i++) {
    if (shape_data[i] == -1) {
      CHECK(!has_wildcard);
      has_wildcard = true;
      wildcard_index = i;
    } else {
      product_non_wildcard_dims *= shape_data[i];
    }
  }
  const int input_flat_size = RequiredBufferSizeForShape(input_shape);
  if (has_wildcard) {
    shape_data[wildcard_index] = input_flat_size / product_non_wildcard_dims;
  }
  auto& output_shape = *output_array.mutable_shape();
  *output_shape.mutable_dims() = shape_data;
  const int output_flat_size = RequiredBufferSizeForShape(output_shape);
  CHECK_EQ(output_flat_size, input_flat_size);
}

void ProcessSimpleOperator(Model* model, Operator* op) {
  const auto& input_array = *model->arrays[op->inputs[0]];
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }

  const string& output_name = op->outputs[0];
  auto& output_array = *model->arrays[output_name];
  if (output_array.has_shape()) {
    return;
  }

  output_array.copy_shape(input_array.shape());
}

void ProcessSimpleBinaryOperator(Model* model, Operator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  const auto& input0_array = *model->arrays[op->inputs[0]];
  const auto& input1_array = *model->arrays[op->inputs[1]];
  // Yield until input dims have been resolved.
  if (!input0_array.has_shape() || !input1_array.has_shape()) {
    return;
  }
  const string& output_name = op->outputs[0];
  auto& output_array = *model->arrays[output_name];
  ComputeBinaryOperatorOutputSize(input0_array.shape(), input1_array.shape(),
                                  &output_array);
}

bool KeepDims(const Operator& op) {
  switch (op.type) {
    case OperatorType::kTensorFlowMin:
      return static_cast<const TensorFlowMinOperator&>(op).keep_dims;
    case OperatorType::kTensorFlowMax:
      return static_cast<const TensorFlowMaxOperator&>(op).keep_dims;
    case OperatorType::kTensorFlowSum:
      return static_cast<const TensorFlowSumOperator&>(op).keep_dims;
    case OperatorType::kMean:
      return static_cast<const MeanOperator&>(op).keep_dims;
    default:
      LOG(FATAL) << "Not a reduction operator!";
      return false;
  }
}

void ProcessTensorFlowReductionOperator(Model* model, Operator* op) {
  CHECK_LE(op->inputs.size(), 2);
  auto& output_array = *model->arrays[op->outputs[0]];
  if (output_array.has_shape()) {
    return;
  }
  const auto& input_array = *model->arrays[op->inputs[0]];
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  const bool keep_dims = KeepDims(*op);
  if (op->inputs.size() == 2) {
    // There is a reduction_indices input.
    const auto& reduction_array = *model->arrays[op->inputs[1]];
    if (!reduction_array.buffer) {
      return;
    }
    CHECK(reduction_array.buffer->type == ArrayDataType::kInt32);
    const auto& reduction_array_vals =
        reduction_array.GetBuffer<ArrayDataType::kInt32>().data;
    auto& output_dims = *output_array.mutable_shape()->mutable_dims();
    output_dims.clear();
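    // Keep (or collapse to 1, with keep_dims) every dimension not listed in
    // reduction_indices. For example, reducing a [2, 3, 4] input over axis 1
    // yields [2, 4], or [2, 1, 4] with keep_dims.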
    for (int i = 0; i < input_shape.dimensions_count(); i++) {
      bool is_reduction_dim = false;
      for (int r : reduction_array_vals) {
        if (i == r) {
          is_reduction_dim = true;
        }
      }
      if (!is_reduction_dim) {
        output_dims.push_back(input_shape.dims(i));
      } else if (keep_dims) {
        output_dims.push_back(1);
      }
    }
  } else {
    // No reduction_indices means complete reduction to a single scalar.
    if (keep_dims) {
      // With keep_dims, every dimension is reduced to size 1 rather than
      // removed.
      *output_array.mutable_shape()->mutable_dims() =
          std::vector<int>(input_shape.dimensions_count(), 1);
    } else {
      output_array.copy_shape(Shape({}));
    }
  }
}

void ProcessSliceOperator(Model* model, SliceOperator* op) {
  CHECK_EQ(op->inputs.size(), 3);
  CHECK_EQ(op->outputs.size(), 1);

  // Yield until the Slice params have been resolved.
  if (op->begin.empty()) return;

  // Yield until input dims have been resolved.
  const auto& input_array = *model->arrays[op->inputs[0]];
  if (!input_array.has_shape()) return;
  const Shape& input_shape = input_array.shape();

  auto& output_array = *model->arrays[op->outputs[0]];
  if (output_array.has_shape()) return;

  CHECK_EQ(input_shape.dims().size(), op->size.size());
  CHECK_EQ(op->begin.size(), op->size.size());

  std::vector<int> output_dims;
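  // A size of -1 extends the slice to the end of the corresponding dimension.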
  for (int i = 0; i < op->begin.size(); ++i) {
    int size = op->size[i];
    if (size == -1) {
      size = input_array.shape().dims(i) - op->begin[i];
    }
    output_dims.push_back(size);
  }

  *output_array.mutable_shape()->mutable_dims() = output_dims;
}

void ProcessReorderAxesOperator(Model* model, ReorderAxesOperator* op) {
  const string& input_name = op->inputs[0];
  const auto& input_array = *model->arrays[input_name];
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  const string& output_name = op->outputs[0];
  Shape* output_shape = model->GetArray(output_name).mutable_shape();
  ShuffleDims(input_shape, op->input_axes_order, op->output_axes_order,
              output_shape);
}

void ProcessConcatenationOperator(Model* model, ConcatenationOperator* op) {
  // Yield until input dims have been resolved.
  for (const auto& input_name : op->inputs) {
    auto& input_array = *model->arrays[input_name];
    if (!input_array.has_shape()) {
      return;
    }
  }
  auto& output_array = model->GetArray(op->outputs[0]);
  // Use the first input as the basis for the output dimensions.
  const auto& first_input_array = *model->arrays[op->inputs[0]];
  output_array.copy_shape(first_input_array.shape());
  // Determine the concat size, and enforce that all inputs have
  // the same dimensions count.
  int concat_size = 0;
  for (const auto& input_name : op->inputs) {
    auto& input_array = *model->arrays[input_name];
    CHECK(input_array.has_shape());
    if (input_array.shape().dimensions_count() == 0) {
      continue;
    }
    CHECK_EQ(input_array.shape().dimensions_count(),
             output_array.shape().dimensions_count());
    const std::vector<int>& input_dims = input_array.shape().dims();
    CHECK_LT(op->concat_dim, input_dims.size());
    concat_size += input_dims[op->concat_dim];
  }
  // Write the concat_size into the output array's shape.
  auto& output_shape = *output_array.mutable_shape();
  auto& output_dims = *output_shape.mutable_dims();
  CHECK_LT(op->concat_dim, output_shape.dimensions_count());
  output_dims[op->concat_dim] = concat_size;
}

void ProcessTensorFlowSplitOperator(Model* model, TensorFlowSplitOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  const string& input_name = op->inputs[1];
  const auto& input_array = *model->arrays[input_name];
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const Shape& input_shape = input_array.shape();

  // This code is slightly suspect.  The TensorFlow docs say that the axis
  // selection defaults to 0, but we are splitting across the final axis.
  const int input_dims_count = input_shape.dimensions_count();
  const int input_depth = input_shape.dims(input_dims_count - 1);
  CHECK_EQ(input_depth % op->num_split, 0);
  const int split_depth = input_depth / op->num_split;
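  // For example, splitting a [1, 8, 8, 6] input with num_split 3 yields three
  // [1, 8, 8, 2] outputs.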

  Shape output_shape = input_shape;
  (*output_shape.mutable_dims())[input_dims_count - 1] = split_depth;

  CHECK_EQ(op->outputs.size(), op->num_split);
  for (const auto& output : op->outputs) {
    model->arrays[output]->copy_shape(output_shape);
  }
}

void ProcessAveragePoolOperator(Model* model, AveragePoolOperator* op) {
  const string& input_name = op->inputs[0];
  const auto& input_array = *model->arrays[input_name];
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);
  const string& output_name = op->outputs[0];
  const int output_depth = input_shape.dims(3);
  ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
                   op->stride_width, op->stride_height, op->padding.type,
                   model->GetArray(output_name).mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
}

void ProcessMaxPoolOperator(Model* model, MaxPoolOperator* op) {
  const string& input_name = op->inputs[0];
  const auto& input_array = *model->arrays[input_name];
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);
  const string& output_name = op->outputs[0];
  const int output_depth = input_shape.dims(3);
  ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
                   op->stride_width, op->stride_height, op->padding.type,
                   model->GetArray(output_name).mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
}

void ProcessL2PoolOperator(Model* model, L2PoolOperator* op) {
  const string& input_name = op->inputs[0];
  const auto& input_array = *model->arrays[input_name];
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  if (input_shape.dimensions_count() < 4) {
    LOG(FATAL) << "missing dimensions for " << input_name;
  }
  const string& output_name = op->outputs[0];
  const int output_depth = input_shape.dims(3);
  ComputeConvSizes(input_shape, output_depth, op->kwidth, op->kheight,
                   op->stride_width, op->stride_height, op->padding.type,
                   model->GetArray(output_name).mutable_shape(),
                   &op->padding.GetOrCreateFixedPadding());
}

void ProcessResizeBilinearOperator(Model* model, ResizeBilinearOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  CHECK_EQ(op->outputs.size(), 1);

  if (!model->arrays[op->inputs[0]]->has_shape() ||
      !model->arrays[op->inputs[1]]->has_shape()) {
    return;
  }
  const auto& input_data_shape = model->arrays[op->inputs[0]]->shape();

  const string& output_size_name = op->inputs[1];
  const auto& output_size_array = *model->arrays[output_size_name];
  CHECK(output_size_array.data_type == ArrayDataType::kInt32);
  CHECK(output_size_array.has_shape());
  const auto& output_size_shape = output_size_array.shape();
  CHECK_EQ(output_size_shape.dimensions_count(), 1);
  CHECK_EQ(output_size_shape.dims(0), 2);
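  // The size input holds {output_height, output_width}; batch and depth are
  // carried over unchanged from the data input.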
  std::vector<int32> output_shape =
      output_size_array.GetBuffer<ArrayDataType::kInt32>().data;
  model->arrays[op->outputs[0]]->copy_shape(
      Shape({input_data_shape.dims(0), output_shape[0], output_shape[1],
             input_data_shape.dims(3)}));
}

void ProcessLstmCellOperator(Model* model, LstmCellOperator* op) {
  // I/O arrays should be allocated on creation of op.
  QCHECK_EQ(op->inputs.size(), LstmCellOperator::NUM_INPUTS);
  QCHECK_EQ(op->outputs.size(), LstmCellOperator::NUM_OUTPUTS);

  const auto& input_array =
      *model->arrays[op->inputs[LstmCellOperator::DATA_INPUT]];
  // Yield until all input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_GE(input_shape.dimensions_count(), 2);

  const auto& prev_activ_array =
      *model->arrays[op->inputs[LstmCellOperator::PREV_ACTIV_INPUT]];
  // Yield until all input dims have been resolved.
  if (!prev_activ_array.has_shape()) {
    return;
  }
  const auto& prev_activ_shape = prev_activ_array.shape();
  CHECK_GE(prev_activ_shape.dimensions_count(), 2);

  const auto& weights_array =
      *model->arrays[op->inputs[LstmCellOperator::WEIGHTS_INPUT]];
  // Yield until weights dims have been resolved.
  if (!weights_array.has_shape()) {
    return;
  }
  const auto& weights_shape = weights_array.shape();
  CHECK_EQ(weights_shape.dimensions_count(), 2);

  const auto& bias_array =
      *model->arrays[op->inputs[LstmCellOperator::BIASES_INPUT]];
  // Yield until bias dims have been resolved.
  if (!bias_array.has_shape()) {
    return;
  }
  const auto& bias_shape = bias_array.shape();
  CHECK_GE(bias_shape.dimensions_count(), 1);

  const auto& prev_state_array =
      *model->arrays[op->inputs[LstmCellOperator::PREV_STATE_INPUT]];
  // Yield until all input dims have been resolved.
  if (!prev_state_array.has_shape()) {
    return;
  }
  const auto& prev_state_shape = prev_state_array.shape();
  CHECK_GE(prev_state_shape.dimensions_count(), 2);

  const int fc_output_depth = weights_shape.dims(0);
  CHECK_EQ(fc_output_depth, bias_shape.dims(0));
  CHECK_EQ(fc_output_depth % 4, 0);
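  // The fused weights pack the four LSTM gates (input, cell, forget, output)
  // along the output depth, hence the division by 4.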
  const int depth = fc_output_depth / 4;

  const int input_depth = input_shape.dims(input_shape.dimensions_count() - 1);
  const int fc_input_depth = weights_shape.dims(1);
  CHECK_EQ(input_depth + depth, fc_input_depth);
  Shape output_shape(input_shape);
  (*output_shape.mutable_dims())[output_shape.dimensions_count() - 1] = depth;

  // Set output dimensions.
  model->GetArray(op->outputs[LstmCellOperator::STATE_OUTPUT])
      .copy_shape(output_shape);
  model->GetArray(op->outputs[LstmCellOperator::ACTIV_OUTPUT])
      .copy_shape(output_shape);

  Shape concat_temp_shape(input_shape);
  (*concat_temp_shape
        .mutable_dims())[concat_temp_shape.dimensions_count() - 1] =
      fc_input_depth;
  model->GetArray(op->outputs[LstmCellOperator::CONCAT_TEMP])
      .copy_shape(concat_temp_shape);

  Shape activ_temp_shape(input_shape);
  (*activ_temp_shape.mutable_dims())[activ_temp_shape.dimensions_count() - 1] =
      fc_output_depth;
  model->GetArray(op->outputs[LstmCellOperator::ACTIV_TEMP])
      .copy_shape(activ_temp_shape);
}

void ProcessSpaceToBatchNDOperator(Model* model, SpaceToBatchNDOperator* op) {
  const auto& input_array = *model->arrays[op->inputs[0]];
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);
  const auto input_height = input_shape.dims(1);
  const auto input_width = input_shape.dims(2);

  const auto& block_shape_array = *model->arrays[op->inputs[1]];
  const auto& paddings_array = *model->arrays[op->inputs[2]];
  const auto& block_shape_array_shape = block_shape_array.shape();
  const auto& paddings_array_shape = paddings_array.shape();
  QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
  QCHECK_EQ(paddings_array_shape.dimensions_count(), 2);

  // We only support two dimensions.
  QCHECK_EQ(block_shape_array_shape.dims(0), 2);
  if (!block_shape_array.buffer) {
    return;
  }
  QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
  const auto& block_shape_data =
      block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
  auto block_height = block_shape_data[0];
  auto block_width = block_shape_data[1];

  QCHECK_EQ(paddings_array_shape.dims(0), 2);  // Number of block dimensions
  QCHECK_EQ(paddings_array_shape.dims(1), 2);  // Two parameters per dimension.
  if (!paddings_array.buffer) {
    return;
  }
  QCHECK(paddings_array.data_type == ArrayDataType::kInt32);
  const auto& paddings_data =
      paddings_array.GetBuffer<ArrayDataType::kInt32>().data;
  int height_with_paddings = input_height + paddings_data[0] + paddings_data[1];
  int width_with_paddings = input_width + paddings_data[2] + paddings_data[3];
  QCHECK_EQ(height_with_paddings % block_height, 0);
  QCHECK_EQ(width_with_paddings % block_width, 0);
  int output_height = height_with_paddings / block_height;
  int output_width = width_with_paddings / block_width;
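  // For example, a [1, 4, 4, 3] input with block shape {2, 2} and zero
  // paddings yields a [4, 2, 2, 3] output.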

  model->arrays[op->outputs[0]]->copy_shape(
      Shape({input_shape.dims(0) * block_height * block_width, output_height,
             output_width, input_shape.dims(3)}));
}

void ProcessBatchToSpaceNDOperator(Model* model, BatchToSpaceNDOperator* op) {
  const auto& input_array = *model->arrays[op->inputs[0]];
  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) {
    return;
  }
  const auto& input_shape = input_array.shape();
  CHECK_EQ(input_shape.dimensions_count(), 4);
  const auto input_height = input_shape.dims(1);
  const auto input_width = input_shape.dims(2);

  const auto& block_shape_array = *model->arrays[op->inputs[1]];
  const auto& crops_array = *model->arrays[op->inputs[2]];
  const auto& block_shape_array_shape = block_shape_array.shape();
  const auto& crops_array_shape = crops_array.shape();
  QCHECK_EQ(block_shape_array_shape.dimensions_count(), 1);
  QCHECK_EQ(crops_array_shape.dimensions_count(), 2);

  // We only support two dimensions.
  QCHECK_EQ(block_shape_array_shape.dims(0), 2);
  if (!block_shape_array.buffer) {
    return;
  }
  QCHECK(block_shape_array.data_type == ArrayDataType::kInt32);
  const auto& block_shape_data =
      block_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
  auto block_height = block_shape_data[0];
  auto block_width = block_shape_data[1];

  QCHECK_EQ(crops_array_shape.dims(0), 2);  // Number of block dimensions
  QCHECK_EQ(crops_array_shape.dims(1), 2);  // Two parameters per dimension.
  if (!crops_array.buffer) {
    return;
  }
  QCHECK(crops_array.data_type == ArrayDataType::kInt32);
  const auto& crops_data = crops_array.GetBuffer<ArrayDataType::kInt32>().data;
  // We don't support crops yet.
  QCHECK_EQ(crops_data[0], 0);
  QCHECK_EQ(crops_data[1], 0);
  QCHECK_EQ(crops_data[2], 0);
  QCHECK_EQ(crops_data[3], 0);

  QCHECK_EQ(input_shape.dims(0) % (block_height * block_width), 0);

  int output_height = input_height * block_height;
  int output_width = input_width * block_width;
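  // For example, a [4, 2, 2, 3] input with block shape {2, 2} and zero crops
  // yields a [1, 4, 4, 3] output.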

  model->arrays[op->outputs[0]]->copy_shape(
      Shape({input_shape.dims(0) / (block_height * block_width), output_height,
             output_width, input_shape.dims(3)}));
}

void ProcessGatherOperator(Model* model, GatherOperator* op) {
  const auto& input_array = *model->arrays[op->inputs[0]];
  const auto& indices_array = *model->arrays[op->inputs[1]];
  auto& output_array = *model->arrays[op->outputs[0]];

  // Bail if we already know the output shape.
  if (output_array.has_shape()) {
    return;
  }

  // Yield until input dims have been resolved.
  if (!input_array.has_shape() || !indices_array.has_shape()) {
    return;
  }

  const auto& input_shape = input_array.shape();
  const auto& indices_shape = indices_array.shape();
  QCHECK_GE(input_shape.dimensions_count(), 1);
  op->input_rank = input_shape.dimensions_count();

  // We only support 1-D indices.
  QCHECK_EQ(indices_shape.dimensions_count(), 1);

  // Copy the input dimensions to the output except for dimension 0,
  // where the dimension of indices_shape is used.
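  // For example, gathering 5 indices from a [10, 7] input yields [5, 7].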
  auto output_dims = output_array.mutable_shape()->mutable_dims();
  output_dims->push_back(indices_shape.dims(0));
  for (int dim = 1; dim < input_shape.dimensions_count(); dim++) {
    output_dims->push_back(input_shape.dims(dim));
  }
}

void ProcessPadOperator(Model* model, PadOperator* op) {
  CHECK_EQ(op->inputs.size(), 2);
  CHECK_EQ(op->outputs.size(), 1);

  const auto& input_array = *model->arrays[op->inputs[0]];

  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) return;

  if (op->left_padding.empty()) return;
  CHECK_EQ(op->left_padding.size(), op->right_padding.size());

  auto& output_array = *model->arrays[op->outputs[0]];
  if (output_array.has_shape()) return;

  Shape output_shape = input_array.shape();
  std::vector<int>& dims = *output_shape.mutable_dims();
  CHECK_EQ(op->left_padding.size(), dims.size());
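  // Each output dimension grows by its left + right padding; e.g. padding of
  // 1 on both sides of a size-4 dimension yields size 6.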

  for (int i = 0; i < op->left_padding.size(); ++i) {
    dims[i] += op->left_padding[i] + op->right_padding[i];
  }

  output_array.copy_shape(output_shape);
}

void ProcessStridedSliceOperator(Model* model, StridedSliceOperator* op) {
  CHECK_EQ(op->inputs.size(), 4);
  CHECK_EQ(op->outputs.size(), 1);

  const auto& input_array = *model->arrays[op->inputs[0]];

  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) return;

  if (op->start_indices.empty()) return;
  CHECK_EQ(op->start_indices.size(), op->stop_indices.size());
  CHECK_EQ(op->start_indices.size(), op->strides.size());

  auto& output_array = *model->arrays[op->outputs[0]];
  if (output_array.has_shape()) return;

  Shape output_shape = input_array.shape();
  std::vector<int>& dims = *output_shape.mutable_dims();
  CHECK_EQ(op->start_indices.size(), dims.size());
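  // Each dimension covers ceil((stop - start) / stride) elements; e.g. start
  // 0, stop 5, stride 2 selects {0, 2, 4}, giving size 3.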

  for (int i = 0; i < op->start_indices.size(); ++i) {
    const int mask = 1 << i;
    const int start = (op->begin_mask & mask) ? 0 : op->start_indices[i];
    const int stop = (op->end_mask & mask) ? input_array.shape().dims()[i]
                                           : op->stop_indices[i];
    // Round up so that a final partial stride still contributes an element.
    dims[i] = (stop - start + op->strides[i] - 1) / op->strides[i];
  }

  output_array.copy_shape(output_shape);
}

void ProcessSqueezeOperator(Model* model, SqueezeOperator* op) {
  CHECK_EQ(op->inputs.size(), 1);
  CHECK_EQ(op->outputs.size(), 1);

  const auto& input_array = *model->arrays[op->inputs[0]];

  // Yield until input dims have been resolved.
  if (!input_array.has_shape()) return;

  auto& output_array = *model->arrays[op->outputs[0]];
  if (output_array.has_shape()) return;

  const std::vector<int>& input_dims = input_array.shape().dims();
  std::vector<int> output_dims;
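  // A dimension is dropped only if it is 1 and squeeze_dims either is empty
  // (drop all size-1 dims) or explicitly lists it.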

  for (int i = 0; i < input_dims.size(); ++i) {
    if (input_dims[i] != 1 ||
        (!op->squeeze_dims.empty() &&
         std::find(op->squeeze_dims.begin(), op->squeeze_dims.end(), i) ==
             op->squeeze_dims.end())) {
      output_dims.push_back(input_dims[i]);
    }
  }
  *output_array.mutable_shape()->mutable_dims() = output_dims;
}

void ProcessSvdfOperator(Model* model, SvdfOperator* op) {
  CHECK(op->inputs.size() == 3 || op->inputs.size() == 4);
  const auto& input_array = *model->arrays[op->inputs[0]];
  if (!input_array.has_shape()) return;

  auto& weights_feature_array = *model->arrays[op->inputs[1]];
  if (!weights_feature_array.has_shape()) return;

  const auto& weights_time_array = *model->arrays[op->inputs[2]];
  if (!weights_time_array.has_shape()) return;

  const bool has_bias = (op->inputs.size() == 4);
  if (has_bias) {
    const auto& bias_array = *model->arrays[op->inputs[3]];
    if (!bias_array.has_shape()) return;
  }

  const int batch_size = input_array.shape().dims()[0];
  const int num_units = weights_feature_array.shape().dims()[0];
  const int memory_size = weights_time_array.shape().dims()[1];
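  // The SVDF state caches the last memory_size feature activations for each
  // of the num_units filters, per batch element.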

  auto& state_array = model->GetArray(op->outputs[0]);
  state_array.mutable_shape()->ReplaceDims(
      {batch_size, memory_size * num_units});

  auto& output_array = model->GetArray(op->outputs[1]);
  output_array.mutable_shape()->ReplaceDims({batch_size, num_units});
}
}  // namespace

bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
  auto it = model->operators.begin() + op_index;
  auto* op = it->get();
  std::unordered_map<string, std::vector<int>> old_output_dims;
  for (const auto& output : op->outputs) {
    if (model->arrays[output]->has_shape()) {
      old_output_dims[output] = model->arrays[output]->shape().dims();
    }
  }

  switch (op->type) {
    case OperatorType::kBatchNormalization:
    case OperatorType::kL2Normalization:
    case OperatorType::kDequantize:
    case OperatorType::kRelu:
    case OperatorType::kRelu1:
    case OperatorType::kRelu6:
    case OperatorType::kSoftmax:
    case OperatorType::kLogistic:
    case OperatorType::kTanh:
    case OperatorType::kLocalResponseNormalization:
    case OperatorType::kTensorFlowIdentity:
    case OperatorType::kFakeQuant:
    case OperatorType::kTensorFlowRsqrt:
    case OperatorType::kTensorFlowSqrt:
    case OperatorType::kTensorFlowSquare:
    case OperatorType::kTensorFlowAll:
    case OperatorType::kTensorFlowAssert:
    case OperatorType::kCast:
    case OperatorType::kFloor:
      ProcessSimpleOperator(model, op);
      break;
    case OperatorType::kGather:
      ProcessGatherOperator(model, static_cast<GatherOperator*>(op));
      break;

    case OperatorType::kAdd:
    case OperatorType::kSub:
    case OperatorType::kMul:
    case OperatorType::kDiv:
    case OperatorType::kTensorFlowLess:
    case OperatorType::kTensorFlowLessEqual:
    case OperatorType::kTensorFlowGreater:
    case OperatorType::kTensorFlowMaximum:
    case OperatorType::kTensorFlowMinimum:
    case OperatorType::kTensorFlowGreaterEqual:
      ProcessSimpleBinaryOperator(model, op);
      break;
    case OperatorType::kConv:
      ProcessConvOperator(model, static_cast<ConvOperator*>(op));
      break;
    case OperatorType::kDepthwiseConv:
      ProcessDepthwiseConvOperator(model,
                                   static_cast<DepthwiseConvOperator*>(op));
      break;
    case OperatorType::kDepthToSpace:
      ProcessDepthToSpaceOperator(model,
                                  static_cast<DepthToSpaceOperator*>(op));
      break;
    case OperatorType::kSpaceToDepth:
      ProcessSpaceToDepthOperator(model,
                                  static_cast<SpaceToDepthOperator*>(op));
      break;
    case OperatorType::kFullyConnected:
      ProcessFullyConnectedOperator(model,
                                    static_cast<FullyConnectedOperator*>(op));
      break;
    case OperatorType::kTensorFlowReshape:
      ProcessTensorFlowReshapeOperator(
          model, static_cast<TensorFlowReshapeOperator*>(op));
      break;
    case OperatorType::kAveragePool:
      ProcessAveragePoolOperator(model, static_cast<AveragePoolOperator*>(op));
      break;
    case OperatorType::kMaxPool:
      ProcessMaxPoolOperator(model, static_cast<MaxPoolOperator*>(op));
      break;
    case OperatorType::kL2Pool:
      ProcessL2PoolOperator(model, static_cast<L2PoolOperator*>(op));
      break;
    case OperatorType::kTensorFlowMin:
    case OperatorType::kTensorFlowMax:
    case OperatorType::kTensorFlowSum:
    case OperatorType::kMean:
      ProcessTensorFlowReductionOperator(model, op);
      break;

    case OperatorType::kSlice:
      ProcessSliceOperator(model, static_cast<SliceOperator*>(op));
      break;

    case OperatorType::kTensorFlowTile:
      // We don't currently implement the propagation of fixed sizes through
      // a TensorFlow Tile.
      //
      // Fortunately, we don't need to: so far, we have only dealt with Tile
      // or Slice ops in subgraphs that are identified as L2Normalization.
      // See IdentifyL2Normalization.
      break;
    case OperatorType::kTensorFlowSwitch:
      // We can't know the sizes of the outputs until we have resolved the
      // predicate, and once we have resolved the predicate, the whole
      // Switch node will get resolved away.
      // See ResolveTensorFlowSwitch.
      break;
    case OperatorType::kTensorFlowMerge:
      // No need to bother resolving TensorFlow Merge ops: other graph
      // transformations will remove them anyway.
      // See ResolveTensorFlowMerge.
      break;
    case OperatorType::kTensorFlowSplit:
      ProcessTensorFlowSplitOperator(model,
                                     static_cast<TensorFlowSplitOperator*>(op));
      break;
    case OperatorType::kSqueeze:
      ProcessSqueezeOperator(model, static_cast<SqueezeOperator*>(op));
      break;
    case OperatorType::kTensorFlowConcat:
    case OperatorType::kTensorFlowConcatV2:
      // Unimplemented; hopefully another graph transformation will
      // drop it or rewrite it. Concretely, either ResolveTensorFlowConcat
      // will resolve this node to a DepthConcatenation, or we have a more
      // general non-depth concatenation that will hopefully be dropped, or
      // else, for now, we will abort.
      break;
    case OperatorType::kTensorFlowShape:
      // Unimplemented; hopefully another graph transformation will drop it or
      // rewrite it.
      break;
    case OperatorType::kReorderAxes:
      ProcessReorderAxesOperator(model, static_cast<ReorderAxesOperator*>(op));
      break;
    case OperatorType::kConcatenation:
      ProcessConcatenationOperator(model,
                                   static_cast<ConcatenationOperator*>(op));
      break;
    case OperatorType::kResizeBilinear:
      ProcessResizeBilinearOperator(model,
                                    static_cast<ResizeBilinearOperator*>(op));
      break;
    case OperatorType::kLstmCell:
      ProcessLstmCellOperator(model, static_cast<LstmCellOperator*>(op));
      break;
    case OperatorType::kTensorFlowMatMul:
      // MatMul operators are converted to FullyConnected, after which their
      // shapes are propagated.
      break;
    case OperatorType::kSpaceToBatchND:
      ProcessSpaceToBatchNDOperator(model,
                                    static_cast<SpaceToBatchNDOperator*>(op));
      break;
    case OperatorType::kBatchToSpaceND:
      ProcessBatchToSpaceNDOperator(model,
                                    static_cast<BatchToSpaceNDOperator*>(op));
      break;
    case OperatorType::kPad:
      ProcessPadOperator(model, static_cast<PadOperator*>(op));
      break;
    case OperatorType::kStridedSlice:
      ProcessStridedSliceOperator(model,
                                  static_cast<StridedSliceOperator*>(op));
      break;
    case OperatorType::kTensorFlowUnsupported:
      break;
    case OperatorType::kSvdf:
      ProcessSvdfOperator(model, static_cast<SvdfOperator*>(op));
      break;
    default:
      // Unimplemented, another graph transformation should drop it.
      LOG(FATAL) << "Unhandled operator type " << OperatorTypeName(op->type);
  }

  // Return true if any output dim changed, false if none changed.
  // Assumption: no transformation clears an output shape; they only add
  // shapes.
  for (const auto& output : op->outputs) {
    if (model->arrays[output]->has_shape() &&
        (old_output_dims[output] != model->arrays[output]->shape().dims())) {
      return true;
    }
  }
  return false;
}

}  // namespace toco
